{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "##\n",
    "import os\n",
    "import glob\n",
    "import numpy as np\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_train(path):\n",
    "    with open(path, 'r') as file:\n",
    "        lines = file.readlines()\n",
    "    return lines\n",
    "\n",
    "# Count alignments of target tokens conditioned on each source token\n",
    "def count_align(src_list, tgt_list, align_list, align_dict):\n",
    "    for al in align_list:\n",
    "        als, alt = al.split('-')\n",
    "        als, alt = int(als), int(alt)\n",
    "        s_tok, t_tok = src_list[als], tgt_list[alt]\n",
    "        if s_tok not in align_dict.keys():\n",
    "            align_dict[s_tok] = dict()\n",
    "        if t_tok not in align_dict[s_tok].keys():\n",
    "            align_dict[s_tok][t_tok] = 1\n",
    "        else:\n",
    "            align_dict[s_tok][t_tok] += 1\n",
    "    return align_dict\n",
    "\n",
    "# Compute uncertainty by the align_dict\n",
    "def comp_uncertainty(align_dict):\n",
    "    H = []\n",
    "    for k,v in align_dict.items():\n",
    "        tot_k = sum(v.values())    # total counts of alignment to word k\n",
    "        h = 0\n",
    "        for kk,vv in v.items():\n",
    "            p = vv / float(tot_k)  # p for k-kk alignment\n",
    "            h = h - p * np.log(p)\n",
    "        H.append(h)\n",
    "    return H"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "disk = \"/apdcephfs/share_916081/joelwxjiao\"\n",
    "align_path = disk + \"/significance-test/linguistic\"\n",
    "data_path = \"/wmt14_en_de_base_reju\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4021956 4021956 4021956\n"
     ]
    }
   ],
   "source": [
    "# Readin data\n",
    "path_src = align_path + data_path + \"/train_active.en\"\n",
    "path_tgt = align_path + data_path + \"/train_active.de\"\n",
    "path_ali = align_path + data_path + \"/train_active.alignment\"\n",
    "\n",
    "src_data = read_train(path_src)\n",
    "tgt_data = read_train(path_tgt)\n",
    "ali_data = read_train(path_ali)\n",
    "print(len(src_data), len(tgt_data), len(ali_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4021956/4021956 [02:44<00:00, 24479.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32611\n",
      "2.005513752429398 [2.344597574823132, 2.4896231390260826, 3.78130818645039, 4.154762887225315, 3.128340391782482, 3.69410653646968, 3.9163537121584553, 3.627337110539781, 1.5973391948050726, 3.556133259237476]\n"
     ]
    }
   ],
   "source": [
    "align_dict = dict()\n",
    "for idx in tqdm(range(len(src_data))):\n",
    "    src = src_data[idx].strip('\\n').split()\n",
    "    tgt = tgt_data[idx].strip('\\n').split()\n",
    "    ali = ali_data[idx].strip('\\n').split()\n",
    "    align_dict = count_align(src, tgt, ali, align_dict)\n",
    "\n",
    "print(len(align_dict))\n",
    "\n",
    "H = comp_uncertainty(align_dict)\n",
    "\n",
    "print(np.mean(H), H[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
